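"""Merge the contract and well datasets.

Matches rig contracts to wells on the same rig (``name_in_ihs``) by date
overlap, averages well-level metrics per contract, imputes missing values from
adjacent contracts of the same rig and operator, and writes the results to
./data_py/temp/06_merge_contracts_wells/. The pipeline runs twice: once
excluding wells flagged as deepening candidates and once with all wells.
"""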
import pandas as pd
import pandasql as ps
import numpy as np

#%%
# Columns retained from each contract/well SQL merge result.
keep_columns = [
    'rig_name', 'type', 'day_rate', 'fixture_date', 'start', 'end',
    'spud_date', 'depth_date', 'mri', 'spec', 'max_wd', 'api',
    'contract_id', 'bus_asc_name', 'company_name', 'operator', 'contractor',
    'duration', 'impute_n_strings', 'bid', 'oil_volume', 'gas_volume',
    'water_depth', 'turnkey', 'surf_latitude', 'surf_longitude',
]

# Group-by keys used to collapse matched (well, contract) rows into one row
# per contract.
collapse_columns = [
    'rig_name', 'type', 'day_rate', 'fixture_date', 'start', 'end',
    'max_wd', 'contract_id', 'contractor', 'operator', 'spec', 'turnkey',
]

# Columns averaged over the wells matched to each contract.
collapse_columns_keep = [
    'mri', 'bid', 'oil_volume', 'gas_volume', 'water_depth', 'duration',
    'surf_latitude', 'surf_longitude',
]

# Contract `type` values that set the `reneg` flag to 1.
extend_types = ['Mutual reneg', 'Mutual option', 'Priced option']

for deepening in ['_no_deepening', '']:
    # Two passes over the same inputs: '_no_deepening' keeps only wells whose
    # `deepened_candidate` flag equals False; '' keeps every well and also
    # writes the pre-imputation *_no_impute files below.
    #%% READ IN DATA --------------------------------------------------------------------
    contracts_with_map = pd.read_csv(
        './data_py/temp/05_merge_with_map/contracts_with_map.csv',
        index_col=[0],
        parse_dates=['start', 'end']
    )
    # Positional surrogate key that identifies each contract through the merges.
    contracts_with_map['contract_id'] = list(range(len(contracts_with_map)))

    wells_with_map = pd.read_csv(
        './data_py/temp/05_merge_with_map/wells_with_map.csv',
        index_col=[0],
        parse_dates=['depth_date', 'spud_date']
    )
    if deepening == '_no_deepening':
        wells_with_map = wells_with_map[wells_with_map['deepened_candidate'] == False]  # noqa: E712

    #%% SPLIT INTO CLASSES --------------------------------------------------------------
    # Rig class from max_wd: low <= 200 < mid < 300 <= high; NaN when max_wd is
    # missing. Start from an object column so writing string labels does not
    # upcast a float column (deprecated since pandas 2.1).
    contracts_with_map['spec'] = pd.Series(np.nan, index=contracts_with_map.index, dtype=object)
    contracts_with_map.loc[contracts_with_map['max_wd'] <= 200, 'spec'] = 'low'
    contracts_with_map.loc[
        (contracts_with_map['max_wd'] > 200) & (contracts_with_map['max_wd'] < 300), 'spec'] = 'mid'
    contracts_with_map.loc[contracts_with_map['max_wd'] >= 300, 'spec'] = 'high'

    # Three (possibly overlapping) ways a contract and a well on the same rig
    # (name_in_ihs) can match:
    #   1. the well was spudded inside the contract window,
    #   2. the well reached depth inside the contract window,
    #   3. the contract window lies entirely inside the well's spud-to-depth window.
    # Every WHERE clause tests well columns, which are NULL for contracts without
    # a matching well, so each LEFT JOIN effectively behaves as an inner join.

    #%% MERGE ON START DATE -------------------------------------------------------------
    # i.e. well spudded in the date range of the contract
    sqlcode = '''
        select *
        from contracts_with_map
        left join wells_with_map on contracts_with_map.name_in_ihs=wells_with_map.name_in_ihs
        where contracts_with_map.end >= wells_with_map.spud_date 
            and wells_with_map.spud_date >= contracts_with_map.start
    '''

    df_merged_espud = ps.sqldf(sqlcode, locals())[keep_columns]

    #%% MERGE ON END DATE ---------------------------------------------------------------
    # i.e. well reached depth in the date range of the contract
    sqlcode = '''
        select *
        from contracts_with_map
        left join wells_with_map on contracts_with_map.name_in_ihs=wells_with_map.name_in_ihs
        where contracts_with_map.end >= wells_with_map.depth_date 
            and wells_with_map.depth_date >= contracts_with_map.start
    '''

    df_merged_eddate = ps.sqldf(sqlcode, locals())[keep_columns]

    #%% MERGE FOR ANY CONTRACTS WITHIN A WELL RANGE -------------------------------------
    sqlcode = '''
        select *
        from contracts_with_map
        left join wells_with_map on contracts_with_map.name_in_ihs=wells_with_map.name_in_ihs
        where contracts_with_map.end <= wells_with_map.depth_date 
            and wells_with_map.spud_date <= contracts_with_map.start
    '''

    df_merged_between = ps.sqldf(sqlcode, locals())[keep_columns]

    #%% COMBINE AND COLLAPSE ------------------------------------------------------------
    # pd.concat replaces DataFrame.append (removed in pandas 2.0); same result.
    # A pair matched by several of the merges appears more than once here; the
    # duplicates are removed per (api, contract_id) further down.
    df_all = pd.concat([df_merged_espud, df_merged_eddate, df_merged_between])
    df_all = df_all.loc[:, ~df_all.columns.duplicated()]

    #%% Merge on contract which has the most overlap in days
    # Re-parse the date columns after the SQL round-trip.
    for j in ['start', 'end', 'depth_date', 'spud_date']:
        df_all[j] = pd.to_datetime(df_all[j])

    # Days during which the well's [spud_date, depth_date] window and the
    # contract's [start, end] window overlap.
    df_all['start_overlap'] = np.maximum(df_all['spud_date'], df_all['start'])
    df_all['end_overlap'] = np.minimum(df_all['depth_date'], df_all['end'])
    df_all['overlap'] = (df_all['end_overlap'] - df_all['start_overlap']).dt.days
    df_all['max_overlap'] = df_all.groupby(['api'])['overlap'].transform('max')

    # Keep a pair if the overlap is at least 14 days, or if this contract has the
    # largest overlap among all contracts matched to the well (so a well keeps
    # its best-matching contract even when every overlap is short).
    df_all = df_all[
        (df_all['overlap'] >= 14) | (df_all['max_overlap'] == df_all['overlap'])
    ]

    #%% Average the well-level metrics over all wells drilled under each contract
    df_wells = df_all.drop_duplicates(subset=['api', 'contract_id'])

    # NOTE(review): groupby drops every group whose key in `collapse_columns`
    # contains NaN (e.g. a missing turnkey or spec). Such contracts vanish from
    # df_collapse and are later treated as unmerged (merge = 0). Confirm this is
    # intended; otherwise pass dropna=False (requires pandas >= 1.1).
    df_collapse = (
        df_wells.groupby(
            by=collapse_columns,
            as_index=False
        )
        [collapse_columns_keep]
        .mean()
    )
    df_collapse['merge'] = 1

    if deepening == '':
        # Pre-imputation snapshots, written only for the all-wells pass.
        df_wells.to_csv('./data_py/temp/06_merge_contracts_wells/wells_merged_no_impute.csv')
        df_collapse.to_csv(
            './data_py/temp/06_merge_contracts_wells/contracts_collapse_no_impute.csv')

    #%% GET CONTRACTS AND WELLS THAT ARE NOT MERGED -------------------------------------
    # wells_no_merge is not used below; presumably kept for interactive inspection.
    wells_no_merge = wells_with_map[~wells_with_map['api'].isin(df_all['api'].unique())]
    # .copy() so adding the `merge` column writes to an independent frame
    # instead of a slice of contracts_with_map (SettingWithCopyWarning).
    contracts_no_merge = contracts_with_map[
        ~contracts_with_map['contract_id'].isin(df_collapse['contract_id'].unique())
    ].copy()
    contracts_no_merge['merge'] = 0

    #%% IMPUTATION ----------------------------------------------------------------------
    contracts_all = pd.concat([df_collapse, contracts_no_merge])
    # Order each rig's contracts chronologically so shift(+-1) yields the
    # neighbouring contract, which impute() compares against.
    contracts_all = contracts_all.sort_values(['rig_name', 'start'])
    for metric in ['mri', 'rig_name', 'operator']:
        contracts_all[f'{metric}_prev'] = contracts_all[metric].shift(1)
    for metric in ['mri', 'rig_name', 'operator']:
        contracts_all[f'{metric}_next'] = contracts_all[metric].shift(-1)

    # Imputation flags; updated by impute() below.
    contracts_all['mri_impute'] = 0
    contracts_all['bid_impute'] = 0
    contracts_all['duration_impute'] = 0

    def impute(contracts_all, metric, max_steps=100):
        """Fill missing ``metric`` values from adjacent contracts of the same rig and operator.

        Expects ``contracts_all`` sorted by (rig_name, start) and carrying the
        ``rig_name_prev``/``rig_name_next`` and ``operator_prev``/``operator_next``
        columns built with ``shift(+-1)``. Missing values are first carried
        forward from earlier contracts. Only values still missing afterwards
        are back-filled from later contracts, so a forward fill is never
        overwritten. Each pass moves values by one row, so a fill can chain
        across at most ``max_steps`` consecutive contracts.

        Parameters
        ----------
        contracts_all : pandas.DataFrame
            Contracts table; modified in place.
        metric : str
            Column to impute. ``{metric}_impute`` is set to 1 only on rows whose
            value was actually filled.
        max_steps : int, default 100
            Maximum number of propagation passes in each direction.

        Returns
        -------
        pandas.DataFrame
            ``contracts_all`` itself, after imputation.
        """
        flag = f'{metric}_impute'
        was_missing = contracts_all[metric].isna()
        same_as_prev = (
            (contracts_all['rig_name'] == contracts_all['rig_name_prev'])
            & (contracts_all['operator'] == contracts_all['operator_prev'])
        )
        same_as_next = (
            (contracts_all['rig_name'] == contracts_all['rig_name_next'])
            & (contracts_all['operator'] == contracts_all['operator_next'])
        )

        def _propagate(mask, periods):
            # Copy the neighbouring row's value into every masked row and repeat,
            # so values travel along runs of consecutive missing rows. Once a row
            # is filled its value never changes, so a pass that fills nothing new
            # means further passes are no-ops and we can stop.
            for _ in range(max_steps):
                n_missing = contracts_all.loc[mask, metric].isna().sum()
                contracts_all.loc[mask, metric] = contracts_all[metric].shift(periods)[mask]
                if contracts_all.loc[mask, metric].isna().sum() == n_missing:
                    break

        _propagate(same_as_prev & was_missing, 1)
        # Recompute after the forward pass: only rows that are STILL missing are
        # back-filled, so forward-filled values are kept.
        _propagate(same_as_next & contracts_all[metric].isna(), -1)

        # Flag only rows that actually received a value.
        contracts_all.loc[was_missing & contracts_all[metric].notna(), flag] = 1
        return contracts_all

    for metric in ['mri', 'water_depth', 'bid', 'duration', 'surf_latitude', 'surf_longitude']:
        contracts_all = impute(contracts_all, metric)

    #%% DROP IF MISSING MRI -------------------------------------------------------------
    # Contracts whose mri could be neither merged nor imputed are dropped.
    # .copy() so the columns added below go into an independent frame instead
    # of a filtered view of contracts_all (SettingWithCopyWarning).
    contracts_final = contracts_all.dropna(subset=['mri'])[
        collapse_columns + collapse_columns_keep
        + ['mri_impute', 'bid_impute', 'duration_impute']
    ].copy()

    #%% ASSIGN DURATIONS ----------------------------------------------------------------
    # tau bucket from duration: 2 if < 75, 3 if in [75, 105), 4 if >= 105;
    # NaN when duration is missing (no condition matches).
    dur = contracts_final['duration']
    contracts_final['tau'] = np.select(
        [dur < 75, (dur >= 75) & (dur < 105), dur >= 105],
        [2, 3, 4],
        default=np.nan,
    )

    #%% ASSIGN RENEG --------------------------------------------------------------------
    # 1 if the contract type is one of `extend_types`, else 0.
    contracts_final['reneg'] = 1 * (contracts_final['type'].isin(extend_types))

    #%% SAVE ----------------------------------------------------------------------------
    contracts_final.to_csv(f'./data_py/temp/06_merge_contracts_wells/contracts_merged{deepening}.csv')
